{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Normalization in `gluon`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "from mxnet import nd, autograd\n",
    "from mxnet import gluon\n",
    "import numpy as np\n",
    "mx.random.seed(1)\n",
    "ctx = mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "def transform(data, label):\n",
    "    return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)\n",
    "train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n",
    "                                      batch_size, shuffle=True)\n",
    "test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n",
    "                                     batch_size, shuffle=False)"
   ]
  },
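  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (an optional sketch, not essential to the workflow), we can peek at a single batch to confirm that `transform` produced `float32` images of shape `(batch_size, 1, 28, 28)` with pixel values rescaled to [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one batch: expect data of shape (64, 1, 28, 28) and labels of shape (64,).\n",
    "for data, label in train_data:\n",
    "    print(data.shape, label.shape)\n",
    "    print('max pixel value:', data.max().asscalar())  # should be at most 1.0\n",
    "    break"
   ]
  },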
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a CNN with Batch Normalization\n",
    "\n",
    "We only need to add a few lines to the model. Pay attention that we insert `BatchNorm` layers before activation layers in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_fc = 512\n",
    "net = gluon.nn.Sequential()\n",
    "with net.name_scope():\n",
    "    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5))\n",
    "    net.add(gluon.nn.BatchNorm(axis=1, center=True, scale=True))\n",
    "    net.add(gluon.nn.Activation(activation='relu'))\n",
    "    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))\n",
    "    \n",
    "    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5))\n",
    "    net.add(gluon.nn.BatchNorm(axis=1, center=True, scale=True))\n",
    "    net.add(gluon.nn.Activation(activation='relu'))\n",
    "    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))\n",
    "    \n",
    "    # The Flatten layer collapses all axis, except the first one, into one axis.\n",
    "    net.add(gluon.nn.Flatten())\n",
    "    \n",
    "    net.add(gluon.nn.Dense(num_fc))\n",
    "    net.add(gluon.nn.BatchNorm(axis=1, center=True, scale=True))\n",
    "    net.add(gluon.nn.Activation(activation='relu'))\n",
    "    \n",
    "    net.add(gluon.nn.Dense(num_outputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter initialization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)"
   ]
  },
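  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the parameters initialized, a forward pass on a dummy batch (a quick sketch; the input shape follows from our MNIST transform) verifies that the layers are wired together correctly. It also shows the two modes of `BatchNorm`: inside `autograd.record()` it normalizes with the current batch's statistics (training mode), while outside it uses the running averages it accumulates during training (inference mode)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy batch shaped like our MNIST inputs (sanity check only).\n",
    "x = nd.random_normal(shape=(batch_size, 1, 28, 28), ctx=ctx)\n",
    "\n",
    "# Inference mode: BatchNorm normalizes with its running mean/variance.\n",
    "y = net(x)\n",
    "\n",
    "# Training mode: inside record(), BatchNorm uses this batch's statistics.\n",
    "with autograd.record():\n",
    "    y_train = net(x)\n",
    "\n",
    "print(y.shape)  # expect (64, 10): one score per class"
   ]
  },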
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax cross-entropy Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
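  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SoftmaxCrossEntropyLoss` takes raw, unnormalized scores (logits), applies the softmax internally, and computes the cross-entropy against integer class labels, returning one loss value per example. Here is a tiny illustration with made-up logits (a sketch, not part of the training pipeline):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy logits for a 3-class problem (made-up numbers, for illustration only).\n",
    "logits = nd.array([[2.0, 0.5, 0.1],\n",
    "                   [0.2, 0.2, 3.0]])\n",
    "labels = nd.array([0, 2])  # correct class index for each example\n",
    "# Confident, correct predictions yield small losses.\n",
    "print(softmax_cross_entropy(logits, labels))"
   ]
  },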
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})"
   ]
  },
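  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trainer applies plain SGD to every parameter in the `ParameterDict` returned by `net.collect_params()`. Printing it (an optional peek) lists the convolution and dense weights and biases along with each `BatchNorm`'s scale (gamma), shift (beta), and running statistics; the running statistics receive no gradients and are updated by the forward pass during training rather than by the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every parameter attached to the network (names are auto-generated\n",
    "# under the net's name scope).\n",
    "print(net.collect_params())"
   ]
  },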
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write evaluation loop to calculate accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net):\n",
    "    acc = mx.metric.Accuracy()\n",
    "    for i, (data, label) in enumerate(data_iterator):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        predictions = nd.argmax(output, axis=1)\n",
    "        acc.update(preds=predictions, labels=label)\n",
    "    return acc.get()[1]"
   ]
  },
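  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, the freshly initialized network should score at roughly chance level on the ten MNIST classes. Running the evaluator once (an optional sanity check) confirms both the plumbing and that baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With random weights we expect accuracy near 1/10 = 0.1 (chance level).\n",
    "print(evaluate_accuracy(test_data, net))"
   ]
  },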
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Loss: 0.0512200891564, Train_acc 0.992266666667, Test_acc 0.99\n",
      "Epoch 1. Loss: 0.0295315987395, Train_acc 0.995083333333, Test_acc 0.9915\n",
      "Epoch 2. Loss: 0.0193948982644, Train_acc 0.996583333333, Test_acc 0.9926\n",
      "Epoch 3. Loss: 0.0159150274969, Train_acc 0.997766666667, Test_acc 0.9919\n",
      "Epoch 4. Loss: 0.00999366554562, Train_acc 0.997933333333, Test_acc 0.9927\n",
      "Epoch 5. Loss: 0.00839924895731, Train_acc 0.9992, Test_acc 0.9928\n",
      "Epoch 6. Loss: 0.00596852778984, Train_acc 0.9997, Test_acc 0.9938\n",
      "Epoch 7. Loss: 0.00484333098223, Train_acc 0.9977, Test_acc 0.9912\n",
      "Epoch 8. Loss: 0.0048461284779, Train_acc 0.999533333333, Test_acc 0.9928\n",
      "Epoch 9. Loss: 0.00329836090038, Train_acc 0.999433333333, Test_acc 0.9924\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "smoothing_constant = .01\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            loss = softmax_cross_entropy(output, label)\n",
    "        loss.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "        \n",
    "        ##########################\n",
    "        #  Keep a moving average of the losses\n",
    "        ##########################\n",
    "        curr_loss = nd.mean(loss).asscalar()\n",
    "        moving_loss = (curr_loss if ((i == 0) and (e == 0)) \n",
    "                       else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss)\n",
    "        \n",
    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
    "    print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" % (e, moving_loss, train_accuracy, test_accuracy))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "First of all, this version is way faster than our scratch version (on cpu and gpu), because `mxnet.gluon` has significant optimizations in C++.\n",
    "\n",
    "Secondly, compared with the pure CNN version, the CNN model here with Batch Normalization has higher accuracy even just in the first few epochs! We are absolutely benefiting from the magic power of Batch Normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Introduction to recurrent neural networks](../chapter05_recurrent-neural-networks/simple-rnn.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
